# -*- coding: utf-8 -*-
"""
@date: 2021/8/31 13:49
@file: expand_tail.py
@author: lilong
@desc: Demo of keras.backend.expand_dims and keras.backend.tile on small NumPy arrays.
"""

"""
expand_dims(x, axis=-1)：在下标为axis的位置插入一个长度为1的新轴（axis=-1表示追加到最后）
tile(x, n)：将x在第i个维度上重复n[i]次，x为张量，n为长度与x维度数目相同的整数列表
"""

from keras import backend as K
import numpy as np

# Demonstrates keras.backend.expand_dims and keras.backend.tile on small NumPy
# arrays. Tensors are materialised with K.eval instead of
# `K.get_session()` + `Tensor.eval(session=...)`: K.eval runs the default
# Keras session under TF1 graph mode and returns the value directly under TF2
# eager mode, where get_session / Tensor.eval are no longer available.

# --- expand_dims: insert a new length-1 axis at the given position ---------
x = np.array([[2, 3], [2, 3], [2, 3]])  # shape (3, 2)
print("x shape:", x.shape)

y1 = K.expand_dims(x, 0)   # new axis in front   -> expected (1, 3, 2)
y2 = K.expand_dims(x, 1)   # new axis in middle  -> expected (3, 1, 2)
y3 = K.expand_dims(x, 2)   # new axis at the end -> expected (3, 2, 1)
y4 = K.expand_dims(x, -1)  # -1 also appends at the end, same as axis 2 here
print("y1 shape:", y1.shape)
print("y2 shape:", y2.shape)
print("y3 shape:", y3.shape)
print("y4 shape:", y4.shape)

# --- tile: repeat the array n[i] times along dimension i -------------------
print("----------------------------")
x = np.array([[1, 2], [3, 4], [5, 6]])  # shape (3, 2)
a = K.tile(x, [2, 1])  # rows repeated twice    -> expected (6, 2)
b = K.tile(x, [1, 2])  # columns repeated twice -> expected (3, 4)
print(x, x.shape)
# Print each tiled result's concrete values followed by its static shape,
# rather than the symbolic Tensor repr.
print(K.eval(a), a.shape)
print(K.eval(b), b.shape)

